import os
import json

import torch
import torch_npu
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import swin_tiny_patch4_window7_224 as create_model

def preprocess_image(img_path, img_size):
    """Load an image from disk and turn it into a normalized single-image batch.

    The image is resized with ``transforms.Resize(int(img_size * 1.14))`` (an int
    size makes torchvision match the shorter edge), center-cropped to
    ``img_size`` x ``img_size``, converted to a float tensor, and normalized with
    the standard ImageNet per-channel mean/std values.

    Args:
        img_path: Path of the image file to load.
        img_size: Side length of the square crop fed to the model.

    Returns:
        A tensor of shape ``(1, 3, img_size, img_size)``.

    Raises:
        AssertionError: If ``img_path`` does not exist.
        OSError: If the file cannot be opened/decoded as an image.
    """
    data_transform = transforms.Compose(
        [transforms.Resize(int(img_size * 1.14)),
         transforms.CenterCrop(img_size),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    assert os.path.exists(img_path), f"File: '{img_path}' does not exist."
    # Use a context manager so the file handle is released right away (this is
    # called once per prediction in an interactive loop). Force 3 channels:
    # grayscale, palette ("P") or RGBA images would otherwise yield a tensor
    # whose channel count does not match the 3-channel Normalize above.
    with Image.open(img_path) as raw:
        rgb = raw.convert("RGB")
    tensor = data_transform(rgb)
    # Add the leading batch dimension expected by the model.
    return torch.unsqueeze(tensor, dim=0)

def predict_image(model, img, device, class_indict):
    """Classify a preprocessed image batch without tracking gradients.

    Args:
        model: Callable network whose output holds per-class scores along dim 1.
        img: Input tensor; it is moved to ``device`` before the forward pass.
        device: Device the input is sent to (should match the model's device).
        class_indict: Passed straight through to the caller, unchanged.

    Returns:
        A ``(probabilities, class_indict)`` tuple, where ``probabilities`` is the
        softmax of the model output over dim 1, copied to the CPU.
    """
    with torch.no_grad():
        batch_on_device = img.to(device)
        logits = model(batch_on_device)
        probabilities = logits.softmax(dim=1).cpu()
    return probabilities, class_indict

def _load_class_indices(json_path):
    """Read the class-index JSON file and return it as ``{int index: class name}``."""
    assert os.path.exists(json_path), f"File: '{json_path}' does not exist."
    # Explicit UTF-8 so non-ASCII class names load regardless of the OS locale.
    with open(json_path, "r", encoding="utf-8") as f:
        class_indict = json.load(f)
    # JSON object keys are always strings; convert them to ints so they line up
    # with column indices of the probability tensor.
    return {int(key): value for key, value in class_indict.items()}


def _show_prediction(img_path, predict, class_indices):
    """Draw the image in the current figure, titled with the top-1 class."""
    top = int(predict[0].argmax())
    # Close the file once displayed; convert() materializes the pixels in memory
    # so the displayed image does not depend on the open file handle.
    with Image.open(img_path) as raw:
        plt.clf()
        plt.imshow(raw.convert("RGB"))
    plt.title(f"class: {class_indices.get(top, 'Unknown')}   prob: {predict[0, top].item():.3f}")
    # In interactive mode the figure is only rendered when the GUI event loop
    # gets a chance to run; pause() briefly hands control to it.
    plt.pause(0.001)


def main():
    """Run an interactive Swin-T classification loop on an Ascend NPU.

    Configures the NPU, loads ``./weights/model-9.pth`` into a 5-class model and
    the index->name mapping from ``./class_indices.json``, then repeatedly asks
    for an image path, prints every class probability and shows the image with
    its top-1 prediction. Typing ``exit`` (or sending EOF) ends the loop.
    """
    device = "npu:4"  # NPU id; change to match the actual machine
    torch.npu.set_device(torch.device(device))
    torch.npu.set_compile_mode(jit_compile=False)
    option = {"NPU_FUZZY_COMPILE_BLACKLIST": "Tril"}
    torch.npu.set_option(option)

    img_size = 224
    model = create_model(num_classes=5).npu().eval()
    model_weight_path = "./weights/model-9.pth"
    # NOTE: torch.load unpickles the checkpoint; only load files you trust.
    model.load_state_dict(torch.load(model_weight_path, map_location=device))

    class_indices = _load_class_indices('./class_indices.json')

    plt.ion()  # interactive mode: plotting calls do not block the input loop
    try:
        while True:
            try:
                img_path = input("Enter the path of the image to predict (or 'exit' to quit): ").strip()
            except EOFError:  # e.g. Ctrl-D / closed stdin: leave cleanly
                break
            if img_path.lower() == 'exit':
                break
            # A bad path should not kill the interactive session.
            if not os.path.isfile(img_path):
                print(f"File: '{img_path}' does not exist.")
                continue
            try:
                img = preprocess_image(img_path, img_size)
            except OSError as err:  # includes PIL.UnidentifiedImageError
                print(f"Could not read image '{img_path}': {err}")
                continue

            predict, _ = predict_image(model, img, device, class_indices)

            # One line per class; indices missing from the JSON show as "Unknown".
            for i, prob in enumerate(predict[0].tolist()):
                class_name = class_indices.get(i, "Unknown")
                print(f"class: {class_name:<15} prob: {prob:.5e}")

            _show_prediction(img_path, predict, class_indices)
    finally:
        plt.ioff()  # restore non-interactive mode even if the loop raised

if __name__ == "__main__":
    # Only start the interactive loop when executed as a script, not on import.
    main()